package org.apache.solr.search.grouping;

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import org.apache.lucene.search.*;
import org.apache.lucene.search.grouping.AbstractAllGroupHeadsCollector;
import org.apache.lucene.search.grouping.TermAllGroupHeadsCollector;
import org.apache.lucene.util.OpenBitSet;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.search.*;
import org.apache.solr.search.QueryUtils;
import org.apache.solr.search.grouping.distributed.shardresultserializer.ShardResultTransformer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Responsible for executing a search with a number of {@link Command}
 * instances. A typical search can have more than one {@link Command} instance.
 * 
 * @lucene.experimental
 */
public class CommandHandler {

	/**
	 * Fluent builder for {@link CommandHandler}. A query command and a
	 * searcher are mandatory; all other settings are optional.
	 */
	public static class Builder {

		private SolrIndexSearcher.QueryCommand queryCommand;
		private List<Command> commands = new ArrayList<Command>();
		private SolrIndexSearcher searcher;
		private boolean needDocSet = false;
		private boolean truncateGroups = false;
		private boolean includeHitCount = false;

		/**
		 * Sets the query command to execute. This also initialises the
		 * "compute a DocSet" flag from the command's
		 * {@link SolrIndexSearcher#GET_DOCSET} flag.
		 * 
		 * @param queryCommand
		 *            the query command; must not be {@code null}
		 * @return this
		 */
		public Builder setQueryCommand(
				SolrIndexSearcher.QueryCommand queryCommand) {
			this.queryCommand = queryCommand;
			this.needDocSet = (queryCommand.getFlags() & SolrIndexSearcher.GET_DOCSET) != 0;
			return this;
		}

		/**
		 * Registers a command whose collectors take part in the search.
		 * 
		 * @param commandField
		 *            the command to add
		 * @return this
		 */
		public Builder addCommandField(Command commandField) {
			commands.add(commandField);
			return this;
		}

		/**
		 * Sets the searcher used to execute the query.
		 * 
		 * @param searcher
		 *            the searcher
		 * @return this
		 */
		public Builder setSearcher(SolrIndexSearcher searcher) {
			this.searcher = searcher;
			return this;
		}

		/**
		 * Sets whether to compute a {@link DocSet}. May override the value set
		 * by
		 * {@link #setQueryCommand(org.apache.solr.search.SolrIndexSearcher.QueryCommand)}
		 * .
		 * 
		 * @param needDocSet
		 *            Whether to compute a {@link DocSet}
		 * @return this
		 */
		public Builder setNeedDocSet(boolean needDocSet) {
			this.needDocSet = needDocSet;
			return this;
		}

		/**
		 * Sets whether the total number of hits should be counted during the
		 * search (see {@link CommandHandler#getTotalHitCount()}).
		 * 
		 * @param includeHitCount
		 *            whether to count hits
		 * @return this
		 */
		public Builder setIncludeHitCount(boolean includeHitCount) {
			this.includeHitCount = includeHitCount;
			return this;
		}

		/**
		 * Sets whether the resulting {@link DocSet} should contain only the
		 * group heads of the first command instead of every matching
		 * document.
		 * 
		 * @param truncateGroups
		 *            whether to truncate groups
		 * @return this
		 */
		public Builder setTruncateGroups(boolean truncateGroups) {
			this.truncateGroups = truncateGroups;
			return this;
		}

		/**
		 * Creates the handler.
		 * 
		 * @return a new {@link CommandHandler}
		 * @throws IllegalStateException
		 *             if the query command or the searcher was not set
		 */
		public CommandHandler build() {
			if (queryCommand == null || searcher == null) {
				throw new IllegalStateException("All fields must be set");
			}

			return new CommandHandler(queryCommand, commands, searcher,
					needDocSet, truncateGroups, includeHitCount);
		}

	}

	private final static Logger logger = LoggerFactory
			.getLogger(CommandHandler.class);

	private final SolrIndexSearcher.QueryCommand queryCommand;
	private final List<Command> commands;
	private final SolrIndexSearcher searcher;
	private final boolean needDocSet;
	private final boolean truncateGroups;
	private final boolean includeHitCount;
	/** Set when the search was aborted because the time limit was exceeded. */
	private boolean partialResults = false;
	/** Only populated when {@code includeHitCount} is enabled. */
	private int totalHitCount;

	/** The DocSet computed by {@link #execute()}, or {@code null} if none. */
	private DocSet docSet;

	private CommandHandler(SolrIndexSearcher.QueryCommand queryCommand,
			List<Command> commands, SolrIndexSearcher searcher,
			boolean needDocSet, boolean truncateGroups, boolean includeHitCount) {
		this.queryCommand = queryCommand;
		this.commands = commands;
		this.searcher = searcher;
		this.needDocSet = needDocSet;
		this.truncateGroups = truncateGroups;
		this.includeHitCount = includeHitCount;
	}

	/**
	 * Executes the query once. Every matching document is fed to all
	 * collectors created by the registered commands. Depending on the
	 * configuration, a {@link DocSet} is also computed: either the group
	 * heads of the first command (truncate groups) or all matching documents.
	 * 
	 * @throws IOException
	 *             if the underlying search fails
	 * @throws IllegalStateException
	 *             if group truncation is requested but no command was
	 *             registered
	 */
	@SuppressWarnings("unchecked")
	public void execute() throws IOException {
		List<Collector> collectors = new ArrayList<Collector>(commands.size());
		for (Command command : commands) {
			// a single command may contribute several collectors
			collectors.addAll(command.create());
		}

		SolrIndexSearcher.ProcessedFilter pf = searcher.getProcessedFilter(
				queryCommand.getFilter(), queryCommand.getFilterList());
		Filter luceneFilter = pf.filter;
		Query query = QueryUtils.makeQueryable(queryCommand.getQuery());

		if (truncateGroups) {
			docSet = computeGroupedDocSet(query, luceneFilter, collectors);
		} else if (needDocSet) {
			docSet = computeDocSet(query, luceneFilter, collectors);
		} else if (!collectors.isEmpty()) {
			searchWithTimeLimiter(query, luceneFilter, wrapAll(collectors));
		} else {
			// Nothing to collect for the commands; a search may still be
			// needed to count hits.
			searchWithTimeLimiter(query, luceneFilter, null);
		}
	}

	/**
	 * Searches with the given collectors plus a group-heads collector for the
	 * first command. Returns a DocSet containing only the group heads.
	 */
	private DocSet computeGroupedDocSet(Query query, Filter luceneFilter,
			List<Collector> collectors) throws IOException {
		if (commands.isEmpty()) {
			throw new IllegalStateException(
					"Truncating groups requires at least one command");
		}
		Command firstCommand = commands.get(0);
		// NOTE(review): getKey() is presumably the group field name -- confirm
		AbstractAllGroupHeadsCollector groupHeadsCollector = TermAllGroupHeadsCollector
				.create(firstCommand.getKey(),
						firstCommand.getSortWithinGroup());
		// MultiCollector.wrap returns the sole collector unchanged when only
		// one is present, so no special case for an empty list is needed.
		collectors.add(groupHeadsCollector);
		searchWithTimeLimiter(query, luceneFilter, wrapAll(collectors));

		int maxDoc = searcher.maxDoc();
		long[] bits = groupHeadsCollector.retrieveGroupHeads(maxDoc)
				.getBits();
		return new BitDocSet(new OpenBitSet(bits, bits.length));
	}

	/**
	 * Searches while collecting every matching document into a DocSet. If
	 * any other collectors are given, they are passed to a
	 * DocSetDelegateCollector, which presumably forwards hits to them.
	 */
	private DocSet computeDocSet(Query query, Filter luceneFilter,
			List<Collector> collectors) throws IOException {
		int maxDoc = searcher.maxDoc();
		DocSetCollector docSetCollector;
		if (collectors.isEmpty()) {
			docSetCollector = new DocSetCollector(maxDoc >> 6, maxDoc);
		} else {
			docSetCollector = new DocSetDelegateCollector(maxDoc >> 6, maxDoc,
					wrapAll(collectors));
		}
		searchWithTimeLimiter(query, luceneFilter, docSetCollector);
		return docSetCollector.getDocSet();
	}

	/**
	 * Combines the given (non-empty) list of collectors into one collector.
	 * The array is sized by the list, which may hold more collectors than
	 * there are commands.
	 */
	private static Collector wrapAll(List<Collector> collectors) {
		return MultiCollector.wrap(collectors
				.toArray(new Collector[collectors.size()]));
	}

	/**
	 * Stores the computed DocSet (if any) and the partial-results flag on the
	 * query result. Returns the commands' results as transformed by the given
	 * transformer.
	 * 
	 * @param queryResult
	 *            the result to update
	 * @param transformer
	 *            converts the executed commands into a {@link NamedList}
	 * @return the transformed command results
	 * @throws IOException
	 *             if the transformation fails
	 */
	@SuppressWarnings("unchecked")
	public NamedList processResult(SolrIndexSearcher.QueryResult queryResult,
			ShardResultTransformer transformer) throws IOException {
		if (docSet != null) {
			queryResult.setDocSet(docSet);
		}
		queryResult.setPartialResults(partialResults);
		return transformer.transform(commands);
	}

	/**
	 * Invokes search with the specified filter and collector. If hit counting
	 * is enabled, a hit counter is added. If a time limit has been specified,
	 * the resulting collector is wrapped in a {@link TimeLimitingCollector}.
	 * When the time limit is exceeded, the search stops and
	 * {@code partialResults} is set.
	 * 
	 * @param collector
	 *            the collector to feed; may be {@code null} when only the hit
	 *            count is wanted
	 */
	private void searchWithTimeLimiter(final Query query,
			final Filter luceneFilter, Collector collector) throws IOException {
		TotalHitCountCollector hitCountCollector = new TotalHitCountCollector();
		if (includeHitCount) {
			// MultiCollector.wrap drops null entries, so a null collector
			// leaves just the hit counter.
			collector = MultiCollector.wrap(collector, hitCountCollector);
		}

		if (collector == null) {
			// Nothing would observe the hits; searching with a null collector
			// would only fail inside the searcher.
			return;
		}

		if (queryCommand.getTimeAllowed() > 0) {
			collector = new TimeLimitingCollector(collector,
					TimeLimitingCollector.getGlobalCounter(),
					queryCommand.getTimeAllowed());
		}

		try {
			searcher.search(query, luceneFilter, collector);
		} catch (TimeLimitingCollector.TimeExceededException x) {
			partialResults = true;
			logger.warn("Query: {}; {}", query, x.getMessage());
		}

		if (includeHitCount) {
			totalHitCount = hitCountCollector.getTotalHits();
		}
	}

	/**
	 * Returns the number of hits counted by the last {@link #execute()}.
	 * 
	 * @return the hit count, or 0 if hit counting was not enabled
	 */
	public int getTotalHitCount() {
		return totalHitCount;
	}
}
